{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gluon CIFAR-10 Trained in Local Mode\n",
    "_**ResNet model in Gluon trained locally in a notebook instance**_\n",
    "\n",
    "---\n",
    "\n",
    "---\n",
    "\n",
    "_This notebook was created and tested on an ml.p3.8xlarge notebook instance._\n",
    "\n",
    "## Setup\n",
    "\n",
    "Import libraries and set IAM role ARN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.mxnet import MXNet\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install pre-requisites for local training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!/bin/bash setup.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Data\n",
    "\n",
    "We use the helper scripts to download CIFAR-10 training data and sample images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cifar10_utils import download_training_data\n",
    "download_training_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value `inputs` identifies the location -- we will use this later when we start the training job.\n",
    "\n",
    "Even though we are training within our notebook instance, we'll continue to use the S3 data location since it will allow us to easily transition to training in SageMaker's managed environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-gluon-cifar10')\n",
    "print('input spec (in this case, just an S3 path): {}'.format(inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Script\n",
    "\n",
    "We need to provide a training script that can run on the SageMaker platform. When SageMaker calls your function, it will pass in arguments that describe the training environment. Check the script below to see how this works.\n",
    "\n",
    "The network itself is a pre-built version contained in the [Gluon Model Zoo](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/model_zoo.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat 'cifar10.py'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Train (Local Mode)\n",
    "\n",
    "The ```MXNet``` estimator will create our training job. To switch from training in SageMaker's managed environment to training within a notebook instance, just set `train_instance_type` to `local_gpu`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = MXNet('cifar10.py',\n",
    "          role=role, \n",
    "          train_instance_count=1,\n",
    "          train_instance_type='local_gpu',\n",
    "          framework_version='1.1.0',\n",
    "          hyperparameters={'batch_size': 1024,\n",
    "                           'epochs': 50,\n",
    "                           'learning_rate': 0.1,\n",
    "                           'momentum': 0.9})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After we've constructed our `MXNet` object, we can fit it using the data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our training script can simply read the data from disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "m.fit(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Host\n",
    "\n",
    "After training, we use the MXNet estimator object to deploy an endpoint. Because we trained locally, we'll also deploy the endpoint locally.  The predictor object returned by `deploy` lets us call the endpoint and perform inference on our sample images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = m.deploy(initial_instance_count=1, instance_type='local_gpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate\n",
    "\n",
    "We'll use these CIFAR-10 sample images to test the service:\n",
    "\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/airplane1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/automobile1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/bird1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/cat1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/deer1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/dog1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/frog1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/horse1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/ship1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/truck1.png\" />\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the CIFAR-10 samples, and convert them into format we can use with the prediction endpoint\n",
    "from cifar10_utils import read_images\n",
    "\n",
    "filenames = ['images/airplane1.png',\n",
    "             'images/automobile1.png',\n",
    "             'images/bird1.png',\n",
    "             'images/cat1.png',\n",
    "             'images/deer1.png',\n",
    "             'images/dog1.png',\n",
    "             'images/frog1.png',\n",
    "             'images/horse1.png',\n",
    "             'images/ship1.png',\n",
    "             'images/truck1.png']\n",
    "\n",
    "image_data = read_images(filenames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predictor runs inference on our input data and returns the predicted class label (as a float value, so we convert to int for display)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i, img in enumerate(image_data):\n",
    "    response = predictor.predict(img)\n",
    "    print('image {}: class: {}'.format(i, int(response)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Cleanup\n",
    "\n",
    "After you have finished with this example, remember to delete the prediction endpoint.  Only one local endpoint can be running at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p27",
   "language": "python",
   "name": "conda_mxnet_p27"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
